{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Classifying handwritten digits\n",
    "\n",
    "This notebook shows how ``giotto-tda`` can be used to generate topological features for image classification. We'll be using the famous MNIST dataset, which contains images of handwritten digits and is a standard benchmark for testing new classification algorithms.\n",
    "\n",
    "<div style=\"text-align: center\">\n",
    "<img src='images/mnist.png'>\n",
    "<p style=\"text-align: center;\"> <b>Figure 1:</b> A few digits from the MNIST dataset. Figure reference: <a href=\"https://en.wikipedia.org/wiki/MNIST_database\">en.wikipedia.org/wiki/MNIST_database</a>. </p>\n",
    "</div>\n",
    "\n",
    "If you are looking at a static version of this notebook and would like to run its contents, head over to [GitHub](https://github.com/giotto-ai/giotto-tda/blob/master/examples/MNIST_classification.ipynb).\n",
    "\n",
    "\n",
    "## Useful references\n",
    "\n",
    "* [_A Topological \"Reading\" Lesson: Classification of MNIST using TDA_](https://arxiv.org/abs/1910.08345) by Adélie Garin and Guillaume Tauzin\n",
    "* [_The MNIST Database of Handwritten Digits_](http://yann.lecun.com/exdb/mnist/) by Yann LeCun, Corinna Cortes, and Christopher J.C. Burges\n",
    "\n",
    "**License: AGPLv3**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the MNIST dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To get started, let's fetch the MNIST dataset using one of ``scikit-learn``'s helper functions:"
   ]
  },
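  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.datasets import fetch_openml\n",
    "\n",
    "# as_frame=False returns NumPy arrays rather than a pandas DataFrame,\n",
    "# which the integer indexing and reshaping below rely on\n",
    "X, y = fetch_openml(\"mnist_784\", version=1, return_X_y=True, as_frame=False)"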
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "By looking at the shapes of these arrays"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"X shape: {X.shape}, y shape: {y.shape}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "we see that there are 70,000 images, where each image has 784 features that represent pixel intensity. Let's reshape the feature vector to a 28x28 array and visualise one of the \"8\" digits using ``giotto-tda``'s plotting API:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from gtda.plotting import plot_heatmap\n",
    "\n",
    "im8_idx = np.flatnonzero(y == \"8\")[0]\n",
    "img8 = X[im8_idx].reshape(28, 28)\n",
    "plot_heatmap(img8)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create train and test sets"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For this example, we will work with a small subset of images – to run a full-blown analysis simply change the values of ``train_size`` and ``test_size`` below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "train_size, test_size = 60, 10\n",
    "\n",
    "# Reshape to (n_samples, n_pixels_x, n_pixels_y)\n",
    "X = X.reshape((-1, 28, 28))\n",
    "\n",
    "X_train, X_test, y_train, y_test = train_test_split(\n",
    "    X, y, train_size=train_size, test_size=test_size, stratify=y, random_state=666\n",
    ")\n",
    "\n",
    "print(f\"X_train shape: {X_train.shape}, y_train shape: {y_train.shape}\")\n",
    "print(f\"X_test shape: {X_test.shape}, y_test shape: {y_test.shape}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## From pixels to topological features"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As shown in the figure below, several steps are required to extract topological features from an image. Since our images are made of pixels, it is convenient to use filtrations of [_cubical complexes_](https://giotto-ai.github.io/gtda-docs/latest/theory/glossary.html#cubical-complex) instead of simplicial ones. Let's go through each of these steps for a single \"8\" digit using ``giotto-tda``!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div style=\"text-align: center\">\n",
    "<img src='images/example_pipeline_images.png' width='600'>\n",
    "<p style=\"text-align: center;\"> <b>Figure 2:</b> An example of a topological feature extraction pipeline. Figure reference: <a href=\"https://arxiv.org/abs/1910.08345\">arXiv:1910.08345</a>. </p>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Binarize the image"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In ``giotto-tda``, filtrations of cubical complexes are built from _binary images_ consisting of only black and white pixels. We can convert our greyscale image to binary by applying a threshold on each pixel value via the ``Binarizer`` transformer, where ``threshold`` is expressed as a fraction of the maximum pixel value:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gtda.images import Binarizer\n",
    "\n",
    "# Pick out index of first 8 image\n",
    "im8_idx = np.flatnonzero(y_train == \"8\")[0]\n",
    "# Reshape to (n_samples, n_pixels_x, n_pixels_y) format\n",
    "im8 = X_train[im8_idx][None, :, :]\n",
    "\n",
    "binarizer = Binarizer(threshold=0.4)\n",
    "im8_binarized = binarizer.fit_transform(im8)\n",
    "\n",
    "binarizer.plot(im8_binarized)"
   ]
  },
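  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the thresholding step concrete, here is a minimal pure-Python sketch of binarization on a toy image. It is an illustration rather than ``giotto-tda``'s implementation: the names ``toy_greyscale`` and ``binarize`` are hypothetical, and we assume that the threshold is a fraction of the maximum pixel value and that pixels strictly above the resulting cutoff become foreground:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical 3x3 greyscale image with intensities in [0, 255]\n",
    "toy_greyscale = [\n",
    "    [0, 50, 200],\n",
    "    [120, 255, 90],\n",
    "    [10, 101, 30],\n",
    "]\n",
    "\n",
    "\n",
    "def binarize(image, threshold=0.4):\n",
    "    # Assumed convention: pixels strictly above the cutoff become 1, all others 0\n",
    "    cutoff = threshold * max(max(row) for row in image)\n",
    "    return [[1 if pixel > cutoff else 0 for pixel in row] for row in image]\n",
    "\n",
    "\n",
    "toy_binary = binarize(toy_greyscale)\n",
    "toy_binary"
   ]
  },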
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### From binary image to filtration"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we have a binary image $\\mathcal{B}$ of our \"8\" digit, we can build a wide variety of different filtrations – see the ``giotto-tda`` [docs](https://giotto-ai.github.io/gtda-docs/latest/modules/images.html#filtrations) for a full list. For our example, we'll use the _radial filtration_ $\\mathcal{R}$, which assigns to each pixel $p$ a value corresponding to its distance from a predefined center $c$ of the image\n",
    "\n",
    "$$ \\mathcal{R}(p) = \\left\\{ \\begin{array}{cl} \n",
    "\\lVert c - p \\rVert_2 &\\mbox{if } \\mathcal{B}(p)=1 \\\\ \n",
    "\\mathcal{R}_\\infty &\\mbox{if } \\mathcal{B}(p)=0 \n",
    "\\end{array} \\right. $$\n",
    "\n",
    "where $\\mathcal{R}_\\infty$ is the distance of the pixel that is furthest from $c$. To reproduce the filtered image from the MNIST [article](https://arxiv.org/abs/1910.08345), we'll pick $c = (20,6)$:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gtda.images import RadialFiltration\n",
    "\n",
    "radial_filtration = RadialFiltration(center=np.array([20, 6]))\n",
    "im8_filtration = radial_filtration.fit_transform(im8_binarized)\n",
    "\n",
    "radial_filtration.plot(im8_filtration, colorscale=\"jet\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see from the resulting plot that we've effectively transformed our binary image into a greyscale one, where the pixel values increase as we move from the upper-right to bottom-left of the image! These pixel values can be used to define a filtration of cubical complexes $\\{K_i\\}_{i\\in \\mathrm{Im}(I)}$, where $I$ denotes the greyscale image, $\\mathrm{Im}(I)$ is the set of its pixel values, and $K_i$ contains all pixels with value less than or equal to $i$. In other words, $K_i$ is the sublevel set at value $i$ of the image's cubical complex $K$."
   ]
  },
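  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The definition of $\\mathcal{R}$ and the sublevel sets it induces can also be checked by hand. The following pure-Python sketch is independent of ``giotto-tda`` and makes a few assumptions: ``toy_image``, ``toy_center`` and ``sublevel_set`` are hypothetical names, the center is given as (row, column), and $\\mathcal{R}_\\infty$ is taken over all pixels of the image:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "# Hypothetical 4x4 binary image: 1 = foreground, 0 = background\n",
    "toy_image = [\n",
    "    [0, 1, 1, 0],\n",
    "    [0, 1, 0, 0],\n",
    "    [0, 1, 1, 0],\n",
    "    [0, 0, 0, 0],\n",
    "]\n",
    "toy_center = (0, 0)  # hypothetical center c, given as (row, column)\n",
    "n_rows, n_cols = len(toy_image), len(toy_image[0])\n",
    "\n",
    "\n",
    "def distance_to_center(row, col):\n",
    "    # Euclidean distance between the pixel and the center\n",
    "    return math.hypot(row - toy_center[0], col - toy_center[1])\n",
    "\n",
    "\n",
    "# R_inf: distance from the center to the pixel furthest from it\n",
    "r_inf = max(\n",
    "    distance_to_center(row, col) for row in range(n_rows) for col in range(n_cols)\n",
    ")\n",
    "\n",
    "# Foreground pixels get their distance to the center, background pixels get R_inf\n",
    "toy_filtration = [\n",
    "    [\n",
    "        distance_to_center(row, col) if toy_image[row][col] == 1 else r_inf\n",
    "        for col in range(n_cols)\n",
    "    ]\n",
    "    for row in range(n_rows)\n",
    "]\n",
    "\n",
    "\n",
    "def sublevel_set(threshold):\n",
    "    # Pixels whose filtration value is at most the threshold\n",
    "    return {\n",
    "        (row, col)\n",
    "        for row in range(n_rows)\n",
    "        for col in range(n_cols)\n",
    "        if toy_filtration[row][col] <= threshold\n",
    "    }\n",
    "\n",
    "\n",
    "# Sublevel sets grow with the threshold; background pixels only enter at R_inf\n",
    "sublevel_sizes = [len(sublevel_set(t)) for t in (1.0, 2.0, 3.0, r_inf)]\n",
    "sublevel_sizes"
   ]
  },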
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### From filtration to persistence diagram"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Given a greyscale filtration, it is straightforward to calculate the corresponding persistence diagram. In ``giotto-tda`` we make use of the ``CubicalPersistence`` transformer, which is the cubical analogue of simplicial transformers like ``VietorisRipsPersistence``:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gtda.homology import CubicalPersistence\n",
    "\n",
    "cubical_persistence = CubicalPersistence(n_jobs=-1)\n",
    "im8_cubical = cubical_persistence.fit_transform(im8_filtration)\n",
    "\n",
    "cubical_persistence.plot(im8_cubical)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It works! We can clearly see two persistent $H_1$ generators corresponding to the loops in the digit \"8\", along with a single $H_0$ generator corresponding to its one connected component.\n",
    "\n",
    "As a postprocessing step, it is often convenient to rescale the persistence diagrams, which can be achieved in ``giotto-tda`` as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gtda.diagrams import Scaler\n",
    "\n",
    "scaler = Scaler()\n",
    "im8_scaled = scaler.fit_transform(im8_cubical)\n",
    "\n",
    "scaler.plot(im8_scaled)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### From persistence diagram to representation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The final step is to define a vectorial representation of the persistence diagram that can be used to obtain machine learning features. Following our example from Figure 2, we convolve our persistence diagram with a Gaussian kernel and symmetrize along the main diagonal, a procedure achieved via the [``HeatKernel``](https://giotto-ai.github.io/gtda-docs/latest/modules/generated/diagrams/representations/gtda.diagrams.HeatKernel.html#gtda.diagrams.HeatKernel) transformer:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gtda.diagrams import HeatKernel\n",
    "\n",
    "heat = HeatKernel(sigma=.15, n_bins=60, n_jobs=-1)\n",
    "im8_heat = heat.fit_transform(im8_scaled)\n",
    "\n",
    "# Visualise the heat kernel for H1\n",
    "heat.plot(im8_heat, homology_dimension_idx=1, colorscale='jet')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Combining all steps as a single pipeline\n",
    "\n",
    "We've now seen how each step in Figure 2 is implemented in ``giotto-tda`` – let's combine them as a single ``scikit-learn`` pipeline:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.pipeline import Pipeline\n",
    "from gtda.diagrams import Amplitude\n",
    "\n",
    "steps = [\n",
    "    (\"binarizer\", Binarizer(threshold=0.4)),\n",
    "    (\"filtration\", RadialFiltration(center=np.array([20, 6]))),\n",
    "    (\"diagram\", CubicalPersistence()),\n",
    "    (\"rescaling\", Scaler()),\n",
    "    (\"amplitude\", Amplitude(metric=\"heat\", metric_params={'sigma':0.15, 'n_bins':60}))\n",
    "]\n",
    "\n",
    "heat_pipeline = Pipeline(steps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "im8_pipeline = heat_pipeline.fit_transform(im8)\n",
    "im8_pipeline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the final step we've used the [``Amplitude``](https://giotto-ai.github.io/gtda-docs/latest/modules/generated/diagrams/features/gtda.diagrams.Amplitude.html) transformer to \"vectorize\" the persistence diagram via the heat kernel method above. In our example, this produces a vector of amplitudes $\\mathbf{a} = (a_0, a_1)$ where each amplitude $a_i$ corresponds to a given homology dimension in the persistence diagram. By extracting these feature vectors from each image, we can feed them into a machine learning classifier – let's tackle this in the next section!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Building a full-blown feature extraction pipeline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we've seen how to extract topological features for a single image, let's make it more realistic and extract a wide variety of features over the whole training set. The resulting pipeline resembles the figure below, where different filtrations and vectorizations of persistence diagrams can be concatenated to produce informative feature vectors.\n",
    "\n",
    "<div style=\"text-align: center\">\n",
    "<img src='images/diagram_pipeline_images.png' width='500'>\n",
    "<p style=\"text-align: center;\"> <b>Figure 3:</b> A full-blown topological feature extraction pipeline </p>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To keep things simple, we'll augment our radial filtration with a _height filtration_ $\\mathcal{H}$, defined by choosing a unit vector $v \\in \\mathbb{R}^2$ in some _direction_ and assigning values $\\mathcal{H}(p) = \\langle p, v \\rangle$ based on the distance of $p$ to the hyperplane defined by $v$. Following the article by Garin and Tauzin, we'll pick a uniform set of directions and centers for our filtrations as shown in the figure below.\n",
    "\n",
    "<div style=\"text-align: center\">\n",
    "<img src='images/directions_and_centers.png' width='250'>\n",
    "</div>\n",
    "\n",
    "We'll also generate features from persistence diagrams by using [_persistence entropy_](https://giotto-ai.github.io/gtda-docs/latest/modules/generated/diagrams/features/gtda.diagrams.PersistenceEntropy.html) and a broad set of amplitudes. Putting it all together yields the following pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.pipeline import make_pipeline, make_union\n",
    "from gtda.diagrams import PersistenceEntropy\n",
    "from gtda.images import HeightFiltration\n",
    "\n",
    "direction_list = [[1, 0], [1, 1], [0, 1], [-1, 1], [-1, 0], [-1, -1], [0, -1], [1, -1]]\n",
    "\n",
    "center_list = [\n",
    "    [13, 6],\n",
    "    [6, 13],\n",
    "    [13, 13],\n",
    "    [20, 13],\n",
    "    [13, 20],\n",
    "    [6, 6],\n",
    "    [6, 20],\n",
    "    [20, 6],\n",
    "    [20, 20],\n",
    "]\n",
    "\n",
    "# Create a list of all the filtration transformers we will apply\n",
    "filtration_list = (\n",
    "    [\n",
    "        HeightFiltration(direction=np.array(direction), n_jobs=-1)\n",
    "        for direction in direction_list\n",
    "    ]\n",
    "    + [RadialFiltration(center=np.array(center), n_jobs=-1) for center in center_list]\n",
    ")\n",
    "\n",
    "# Create the diagram generation steps, one list per filtration\n",
    "diagram_steps = [\n",
    "    [\n",
    "        Binarizer(threshold=0.4, n_jobs=-1),\n",
    "        filtration,\n",
    "        CubicalPersistence(n_jobs=-1),\n",
    "        Scaler(n_jobs=-1),\n",
    "    ]\n",
    "    for filtration in filtration_list\n",
    "]\n",
    "\n",
    "# List all the metrics we want to use to extract diagram amplitudes\n",
    "metric_list = [\n",
    "    {\"metric\": \"bottleneck\", \"metric_params\": {}},\n",
    "    {\"metric\": \"wasserstein\", \"metric_params\": {\"p\": 1}},\n",
    "    {\"metric\": \"wasserstein\", \"metric_params\": {\"p\": 2}},\n",
    "    {\"metric\": \"landscape\", \"metric_params\": {\"p\": 1, \"n_layers\": 1, \"n_bins\": 100}},\n",
    "    {\"metric\": \"landscape\", \"metric_params\": {\"p\": 1, \"n_layers\": 2, \"n_bins\": 100}},\n",
    "    {\"metric\": \"landscape\", \"metric_params\": {\"p\": 2, \"n_layers\": 1, \"n_bins\": 100}},\n",
    "    {\"metric\": \"landscape\", \"metric_params\": {\"p\": 2, \"n_layers\": 2, \"n_bins\": 100}},\n",
    "    {\"metric\": \"betti\", \"metric_params\": {\"p\": 1, \"n_bins\": 100}},\n",
    "    {\"metric\": \"betti\", \"metric_params\": {\"p\": 2, \"n_bins\": 100}},\n",
    "    {\"metric\": \"heat\", \"metric_params\": {\"p\": 1, \"sigma\": 1.6, \"n_bins\": 100}},\n",
    "    {\"metric\": \"heat\", \"metric_params\": {\"p\": 1, \"sigma\": 3.2, \"n_bins\": 100}},\n",
    "    {\"metric\": \"heat\", \"metric_params\": {\"p\": 2, \"sigma\": 1.6, \"n_bins\": 100}},\n",
    "    {\"metric\": \"heat\", \"metric_params\": {\"p\": 2, \"sigma\": 3.2, \"n_bins\": 100}},\n",
    "]\n",
    "\n",
    "# Combine persistence entropy and all amplitudes into a single feature union\n",
    "feature_union = make_union(\n",
    "    *[PersistenceEntropy(nan_fill_value=-1)]\n",
    "    + [Amplitude(**metric, n_jobs=-1) for metric in metric_list]\n",
    ")\n",
    "\n",
    "tda_union = make_union(\n",
    "    *[make_pipeline(*diagram_step, feature_union) for diagram_step in diagram_steps],\n",
    "    n_jobs=-1\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "which can be visualised using ``scikit-learn``'s nifty [HTML feature](https://scikit-learn.org/stable/modules/compose.html#visualizing-composite-estimators):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn import set_config\n",
    "set_config(display='diagram')  \n",
    "\n",
    "tda_union"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It's now a simple matter to run the whole pipeline:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train_tda = tda_union.fit_transform(X_train)\n",
    "X_train_tda.shape"
   ]
  },
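  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The number of columns in ``X_train_tda`` can be predicted from the pipeline definition. Here is a minimal sketch of that count; the variable names are hypothetical, and we assume that ``CubicalPersistence`` computes homology dimensions 0 and 1 by default and that ``PersistenceEntropy`` and each ``Amplitude`` return one value per homology dimension:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical bookkeeping that mirrors the pipeline defined above\n",
    "n_height_filtrations = 8  # one per direction\n",
    "n_radial_filtrations = 9  # one per center\n",
    "n_homology_dimensions = 2  # H0 and H1 (assumed defaults)\n",
    "n_vectorizations = 1 + 13  # persistence entropy plus 13 amplitudes\n",
    "\n",
    "expected_n_features = (\n",
    "    (n_height_filtrations + n_radial_filtrations)\n",
    "    * n_homology_dimensions\n",
    "    * n_vectorizations\n",
    ")\n",
    "expected_n_features"
   ]
  },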
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training a classifier"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We see that we have generated $(8 + 9) \\times 2 \\times 14 = 476$ topological features per image (filtrations $\\times$ homology dimensions $\\times$ vectorizations)! In general, some of these features will be highly correlated and a feature selection procedure could be used to select the most informative ones. Nevertheless, let's train a Random Forest classifier on our training set to see what kind of performance we can get:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.ensemble import RandomForestClassifier\n",
    "\n",
    "rf = RandomForestClassifier()\n",
    "rf.fit(X_train_tda, y_train)\n",
    "\n",
    "X_test_tda = tda_union.transform(X_test)\n",
    "rf.score(X_test_tda, y_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For such a small dataset, this accuracy is not too bad, but accuracies above 96% can be achieved by training on the full MNIST dataset together with feature selection strategies."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Using hyperparameter search with topological pipelines"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the above pipeline, we can think of our choices for the directions and centers of the filtrations as hyperparameters. To wrap up our analysis, let's see how we can run a hyperparameter search over the directions of the height filtration. We'll use a simplified pipeline to show the main steps, but note that a realistic application would involve running the search over a pipeline like the one in the previous section.\n",
    "\n",
    "As usual, we define our pipeline in terms of topological transformers and an estimator as the final step:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "height_pipeline = Pipeline([\n",
    "    ('binarizer', Binarizer(threshold=0.4)),\n",
    "    ('filtration', HeightFiltration()),\n",
    "    ('diagram', CubicalPersistence()),\n",
    "    ('feature', PersistenceEntropy(nan_fill_value=-1)),\n",
    "    ('classifier', RandomForestClassifier(random_state=42))\n",
    "])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next we can search for the best combination of directions, homology dimensions, and number of trees in our Random Forest as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import GridSearchCV\n",
    "\n",
    "direction_list = [[1, 0], [1, 1], [0, 1], [-1, 1], [-1, 0], [-1, -1], [0, -1], [1, -1]]\n",
    "homology_dimensions_list = [[0], [1]]\n",
    "n_estimators_list = [500, 1000, 2000]\n",
    "\n",
    "# Keys follow scikit-learn's <step name>__<parameter> convention\n",
    "param_grid = {\n",
    "    \"filtration__direction\": [np.array(direction) for direction in direction_list],\n",
    "    \"diagram__homology_dimensions\": homology_dimensions_list,\n",
    "    \"classifier__n_estimators\": n_estimators_list,\n",
    "}\n",
    "\n",
    "grid_search = GridSearchCV(\n",
    "    estimator=height_pipeline, param_grid=param_grid, cv=3, n_jobs=-1\n",
    ")\n",
    "\n",
    "grid_search.fit(X_train, y_train)"
   ]
  },
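  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Grid search is exhaustive, so its cost grows with the product of the grid sizes. A quick count, written with the standard library only and using hypothetical variable names, gives a rough idea of how many pipelines are fitted during cross-validation for the grid above:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from itertools import product\n",
    "\n",
    "# Grid values repeated here so that this cell stands alone\n",
    "directions = [[1, 0], [1, 1], [0, 1], [-1, 1], [-1, 0], [-1, -1], [0, -1], [1, -1]]\n",
    "homology_dimensions = [[0], [1]]\n",
    "n_estimators = [500, 1000, 2000]\n",
    "n_folds = 3  # matches cv=3 in the GridSearchCV call\n",
    "\n",
    "# Every combination of the three hyperparameters is one candidate pipeline\n",
    "candidates = list(product(directions, homology_dimensions, n_estimators))\n",
    "n_candidates = len(candidates)\n",
    "\n",
    "# Each candidate is fitted once per fold; the final refit of the best model is not counted\n",
    "n_cv_fits = n_candidates * n_folds\n",
    "n_candidates, n_cv_fits"
   ]
  },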
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "By looking at the best hyperparameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "grid_search.best_params_"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "we see that the direction [1, 0] with homology dimension 0 produces the best features. By comparing, say, a \"6\" and a \"9\" digit, can you think of a reason why this might be the case?"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
